package com.yeskery.nut.websocket;

import com.yeskery.nut.bean.ApplicationContext;
import com.yeskery.nut.bean.BeanIterable;
import net.sf.cglib.proxy.Callback;
import net.sf.cglib.proxy.Enhancer;
import net.sf.cglib.proxy.InvocationHandler;

import javax.websocket.DeploymentException;
import javax.websocket.Endpoint;
import javax.websocket.server.ServerContainer;
import javax.websocket.server.ServerEndpoint;
import javax.websocket.server.ServerEndpointConfig;
import java.util.Collection;
import java.util.Collections;
import java.util.stream.Collectors;

/**
 * Javax WebSocket loader.
 * <p>
 * Registers endpoints with a {@code javax.websocket} {@link ServerContainer}: one CGLIB-generated
 * {@link Endpoint} subclass per {@link WebSocketConfiguration}, plus the class of every singleton bean
 * in the application context that is annotated with {@link ServerEndpoint}.
 * @author sprout
 * @version 1.0
 * 2023-04-16 18:48
 */
public class JavaxWebSocketLoader extends ApplicationWebSocketLoader {

    /**
     * Creates a Javax WebSocket loader.
     * @param applicationContext the application context
     */
    public JavaxWebSocketLoader(ApplicationContext applicationContext) {
        super(applicationContext);
    }

    /**
     * Registers the WebSocket server endpoints.
     * <p>
     * For each configuration, a dedicated {@link Endpoint} subclass is generated and deployed at the
     * configuration's path. Afterwards, every {@link ServerEndpoint}-annotated singleton bean class is
     * deployed via {@link ServerContainer#addEndpoint(Class)}.
     * @param serverContainer the websocket server container
     * @param webSocketConfigurations the WebSocket configurations to register
     * @throws WebSocketException if the container rejects an endpoint deployment
     */
    @Override
    public void registerEndpoints(ServerContainer serverContainer, Collection<WebSocketConfiguration> webSocketConfigurations) {
        for (WebSocketConfiguration webSocketConfiguration : webSocketConfigurations) {
            Class<?> endpointClass = createEndpointClass(webSocketConfiguration);
            try {
                serverContainer.addEndpoint(ServerEndpointConfig.Builder.create(endpointClass,
                        webSocketConfiguration.getWebSocketServerConfigure().getPath()).build());
            } catch (DeploymentException e) {
                throw new WebSocketException("WebSocket ServerEndpointConfig Deploy Fail.", e);
            }
        }

        // load javax.websocket ServerEndpoint
        // NOTE(review): only the bean's class is handed to the container, so the container creates its
        // own instances rather than using the singleton bean itself — confirm this is intended.
        for (Class<?> endpointBeanClass : getJavaxWebsocketServerEndpointBeanClasses()) {
            try {
                serverContainer.addEndpoint(endpointBeanClass);
            } catch (DeploymentException e) {
                throw new WebSocketException("WebSocket ServerEndpoint Class Deploy Fail.", e);
            }
        }
    }

    /**
     * Generates an {@link Endpoint} subclass whose methods dispatch to a handler bound to the given
     * configuration.
     * @param webSocketConfiguration the configuration the generated endpoint delegates to
     * @return the generated endpoint class, with its static callbacks already registered
     */
    private static Class<?> createEndpointClass(WebSocketConfiguration webSocketConfiguration) {
        Enhancer enhancer = new Enhancer();
        enhancer.setSuperclass(Endpoint.class);
        enhancer.setCallbackType(InvocationHandler.class);
        // CGLIB caches generated classes by superclass + callback *types*, not callback instances.
        // With the cache on, every configuration would share one class, and each
        // registerStaticCallbacks call would overwrite the previous handler. A fresh class per
        // configuration keeps the handlers apart.
        enhancer.setUseCache(false);
        Class<?> generatedClass = enhancer.createClass();
        InvocationHandler invocationHandler = new WebSocketCglibInvocationHandler(webSocketConfiguration);
        // The container instantiates the class itself (no-arg constructor), so the handler must be bound
        // statically rather than passed at construction time.
        Enhancer.registerStaticCallbacks(generatedClass, new Callback[]{invocationHandler});
        return generatedClass;
    }

    /**
     * Returns the classes of the Javax WebSocket server endpoint beans.
     * <p>
     * Only singleton beans are considered, and only when the application context is a
     * {@link BeanIterable}. Each class appears at most once.
     * @return the distinct {@link ServerEndpoint}-annotated bean classes; empty if the context cannot be
     *         iterated
     */
    protected Collection<Class<?>> getJavaxWebsocketServerEndpointBeanClasses() {
        ApplicationContext applicationContext = getApplicationContext();
        if (applicationContext instanceof BeanIterable) {
            return ((BeanIterable) applicationContext).getSingletonBeans().values()
                    .stream()
                    .map(Object::getClass)
                    .filter(c -> c.isAnnotationPresent(ServerEndpoint.class))
                    // Several beans may share one class; deploying the same class (and hence the same
                    // path) twice would be rejected by the container.
                    .distinct()
                    .collect(Collectors.toList());
        }
        return Collections.emptySet();
    }

    /**
     * Returns the Javax WebSocket server endpoint bean instances.
     * <p>
     * Only singleton beans are considered, and only when the application context is a
     * {@link BeanIterable}.
     * @return the singleton beans whose class is annotated with {@link ServerEndpoint}; empty if the
     *         context cannot be iterated
     */
    protected Collection<Object> getJavaxWebsocketServerEndpointBeans() {
        ApplicationContext applicationContext = getApplicationContext();
        if (applicationContext instanceof BeanIterable) {
            return ((BeanIterable) applicationContext).getSingletonBeans().values()
                    .stream()
                    .filter(o -> o.getClass().isAnnotationPresent(ServerEndpoint.class))
                    .collect(Collectors.toList());
        }
        return Collections.emptySet();
    }
}
